{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4620bc8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip3 install gpt-2-simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "477a8024",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from gpt_2_simple.src import model as gpt2_model, encoder\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9528fd8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# JSON file with hyper-parameter overrides; it is applied on top of\n",
    "# gpt2_model.default_hparams() in the next cell.\n",
    "params = '345m-hparams.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cc10d606",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the library defaults and overlay the values stored in `params`.\n",
    "hparams = gpt2_model.default_hparams()\n",
    "with open(params, 'r', encoding='utf-8') as f:\n",
    "    hparams.override_from_dict(json.load(f))\n",
    "\n",
    "# Read every vocabulary file as UTF-8 explicitly instead of relying on the\n",
    "# locale-dependent default encoding of open().\n",
    "with open('encoder.json', 'r', encoding='utf-8') as f:\n",
    "    en = json.load(f)\n",
    "with open('vocab.bpe', 'r', encoding='utf-8') as f:\n",
    "    bpe_data = f.read()\n",
    "\n",
    "# One merge rule per line. The first and last entries of the split are\n",
    "# dropped (presumably a header line and the empty string that follows the\n",
    "# final newline - TODO confirm against the vocab file).\n",
    "bpe_merges = [\n",
    "    tuple(merge_str.split()) for merge_str in bpe_data.split('\\n')[1:-1]\n",
    "]\n",
    "enc_malay = encoder.Encoder(encoder=en, bpe_merges=bpe_merges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "04232aaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "def top_k_logits(logits, k):\n",
    "    '''Keep only the `k` largest logits of every row.\n",
    "\n",
    "    Entries smaller than the k-th largest value of their row are replaced\n",
    "    with -1e10. When `k` equals 0 the logits are returned unchanged.\n",
    "    '''\n",
    "\n",
    "    def _mask_below_kth():\n",
    "        top_values, _ = tf.nn.top_k(logits, k=k)\n",
    "        # The last column is the smallest of the k selected values; the extra\n",
    "        # axis lets the comparison broadcast across each row.\n",
    "        kth_largest = top_values[:, -1, tf.newaxis]\n",
    "        is_below = logits < kth_largest\n",
    "        floor = tf.ones_like(logits, dtype=logits.dtype) * -1e10\n",
    "        return tf.where(is_below, floor, logits)\n",
    "\n",
    "    return tf.cond(\n",
    "        pred=tf.equal(k, 0), true_fn=lambda: logits, false_fn=_mask_below_kth\n",
    "    )\n",
    "\n",
    "\n",
    "def top_p_logits(logits, p):\n",
    "    '''Nucleus filtering: mask the logits that fall outside the top-`p` mass.\n",
    "\n",
    "    Logits are sorted in descending order and an entry is retained while the\n",
    "    softmax mass of the entries ranked before it is below `p`. Every logit\n",
    "    smaller than the smallest retained one is replaced with -1e10.\n",
    "    '''\n",
    "    with tf.variable_scope('top_p_logits'):\n",
    "        sorted_logits = tf.sort(logits, direction='DESCENDING')\n",
    "        sorted_probs = tf.nn.softmax(sorted_logits)\n",
    "        # exclusive=True: each position sees the mass strictly before it.\n",
    "        mass_before = tf.cumsum(sorted_probs, axis=1, exclusive=True)\n",
    "        inside_nucleus = mass_before < p\n",
    "        # Rejected entries are set to 1000 before the row-wise minimum is taken.\n",
    "        sentinel = tf.ones_like(sorted_logits) * 1000\n",
    "        retained = tf.where(inside_nucleus, sorted_logits, sentinel)\n",
    "        cutoff = tf.reduce_min(input_tensor=retained, axis=1, keepdims=True)\n",
    "        below_cutoff = logits < cutoff\n",
    "        floor = tf.ones_like(logits, dtype=logits.dtype) * -1e10\n",
    "        return tf.where(below_cutoff, floor, logits)\n",
    "\n",
    "\n",
    "def sample_sequence(\n",
    "    hparams,\n",
    "    length,\n",
    "    start_token=None,\n",
    "    batch_size=None,\n",
    "    context=None,\n",
    "    temperature=1,\n",
    "    top_k=0,\n",
    "    top_p=0.0,\n",
    "):\n",
    "    '''Build graph ops that extend `context` with `length` sampled tokens.\n",
    "\n",
    "    Args:\n",
    "        hparams: passed through to gpt2_model.model and gpt2_model.past_shape;\n",
    "            `hparams.n_vocab` bounds the logits that are kept.\n",
    "        length: number of loop iterations; one token is appended per iteration.\n",
    "        start_token: alternative to `context`; when given, the context becomes\n",
    "            tf.fill([batch_size, 1], start_token).\n",
    "        batch_size: only read when `start_token` is given.\n",
    "        context: token ids indexed as [batch, position].\n",
    "        temperature: scale of the random noise added to the logits when > 0.\n",
    "        top_k: forwarded to top_k_logits when `top_p` is not > 0.\n",
    "        top_p: when > 0, top_p_logits is applied instead of top_k_logits.\n",
    "\n",
    "    Returns:\n",
    "        The context with the sampled tokens concatenated along axis 1.\n",
    "    '''\n",
    "    # Exactly one of start_token / context must be supplied.\n",
    "    if start_token is None:\n",
    "        assert (\n",
    "            context is not None\n",
    "        ), 'Specify exactly one of start_token and context!'\n",
    "    else:\n",
    "        assert (\n",
    "            context is None\n",
    "        ), 'Specify exactly one of start_token and context!'\n",
    "        context = tf.fill([batch_size, 1], start_token)\n",
    "\n",
    "    # Runs the language model on `tokens` (optionally with cached `past`) and\n",
    "    # returns its logits, cut to the first n_vocab entries, plus its 'present' output.\n",
    "    def step(hparams, tokens, past=None):\n",
    "        lm_output = gpt2_model.model(\n",
    "            hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE\n",
    "        )\n",
    "\n",
    "        logits = lm_output['logits'][:, :, : hparams.n_vocab]\n",
    "        presents = lm_output['present']\n",
    "        # Pin the static shape with an unknown batch dimension.\n",
    "        presents.set_shape(\n",
    "            gpt2_model.past_shape(hparams=hparams, batch_size=None)\n",
    "        )\n",
    "        return {'logits': logits, 'presents': presents}\n",
    "\n",
    "    with tf.name_scope('sample_sequence'):\n",
    "        # Loop counter.\n",
    "        lens = tf.constant(0, dtype=tf.int32)\n",
    "        # Feed every context token except the last one; the last token is the\n",
    "        # first `prev` of the loop below.\n",
    "        context_output = step(hparams, context[:, :-1])\n",
    "        \n",
    "        # Adds -log(-log(u)) noise, with u drawn uniformly, scaled by `temperature`.\n",
    "        def apply_temp(logits_BxN, temperature):\n",
    "            logits_shape = tf.shape(logits_BxN)\n",
    "            uniform_noise_BxN = tf.random_uniform(logits_shape)\n",
    "            logits_BxN += -tf.log(-tf.log(uniform_noise_BxN)) * temperature\n",
    "            return logits_BxN\n",
    "\n",
    "        # One decoding step: returns the updated loop variables.\n",
    "        def body(past, prev, output, lens):\n",
    "            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)\n",
    "            logits = next_outputs['logits'][:, -1, :]  \n",
    "            # NOTE(review): when temperature is not > 0 the logits pass through\n",
    "            # unchanged and are still sampled by tf.random.categorical below, so\n",
    "            # temperature 0 is not greedy decoding - confirm this is intended.\n",
    "            logits = tf.cond(\n",
    "                temperature > 0,\n",
    "                lambda: apply_temp(logits, temperature),\n",
    "                lambda: logits,\n",
    "            )\n",
    "            # top_p > 0 selects nucleus filtering, otherwise top-k filtering is used.\n",
    "            logits = tf.cond(top_p > 0.0, lambda: top_p_logits(logits, p=top_p),\n",
    "                             lambda: top_k_logits(logits, k=top_k))\n",
    "            samples = tf.random.categorical(\n",
    "                logits, num_samples=1, dtype=tf.int32\n",
    "            )\n",
    "            return [\n",
    "                tf.concat([past, next_outputs['presents']], axis=-2),\n",
    "                tf.squeeze(samples, axis=[1]),\n",
    "                tf.concat([output, samples], axis=1),\n",
    "                lens + 1\n",
    "            ]\n",
    "\n",
    "        # Continue until `length` tokens have been generated.\n",
    "        def cond(past, prev, output, lens):\n",
    "            return tf.less(lens, length)\n",
    "\n",
    "        # The shape invariants leave dimensions open because `past` and `output`\n",
    "        # grow by concatenation on every iteration.\n",
    "        _, _, tokens, _ = tf.while_loop(\n",
    "            cond=cond,\n",
    "            body=body,\n",
    "            loop_vars=[context_output['presents'], context[:, -1], context, lens],\n",
    "            shape_invariants=[\n",
    "                tf.TensorShape(\n",
    "                    gpt2_model.past_shape(\n",
    "                        hparams=hparams, batch_size=None\n",
    "                    )\n",
    "                ),\n",
    "                tf.TensorShape([None]),\n",
    "                tf.TensorShape([None, None]),\n",
    "                lens.get_shape(),\n",
    "            ],\n",
    "            back_prop=False,\n",
    "        )\n",
    "\n",
    "        return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7d1191da",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    '''Sampling graph whose inputs are exposed as named placeholders.\n",
    "\n",
    "    The placeholders are named 'X', 'temp', 'top_k', 'top_p', 'maxlen' and\n",
    "    'n_samples'; the generated tokens are exposed through an identity op\n",
    "    named 'output'.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, hparams, encoder, **kwargs):\n",
    "        self._encoder = encoder\n",
    "\n",
    "        # A single prompt of variable length.\n",
    "        self._X = tf.placeholder(tf.int32, shape=[1, None], name='X')\n",
    "        # Sampling controls; their shapes are left unspecified.\n",
    "        self._temperature = tf.placeholder(tf.float32, shape=None, name='temp')\n",
    "        self._top_k = tf.placeholder(tf.int32, shape=None, name='top_k')\n",
    "        self._top_p = tf.placeholder(tf.float32, shape=None, name='top_p')\n",
    "        self._maxlen = tf.placeholder(tf.int32, shape=None, name='maxlen')\n",
    "        self._n_samples = tf.placeholder(tf.int32, shape=None, name='n_samples')\n",
    "\n",
    "        # Repeat the prompt once per requested sample.\n",
    "        tiled_prompt = tf.tile(self._X, [self._n_samples, 1])\n",
    "        self._model = sample_sequence(\n",
    "            hparams=hparams,\n",
    "            length=self._maxlen,\n",
    "            context=tiled_prompt,\n",
    "            batch_size=self._n_samples,\n",
    "            temperature=self._temperature,\n",
    "            top_k=self._top_k,\n",
    "            top_p=self._top_p,\n",
    "        )\n",
    "        self.output = tf.identity(self._model, name='output')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5c4a7b0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-5-31e2ecd21534>:27: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "# Build the sampling graph with the loaded hyper-parameters and BPE encoder.\n",
    "model = Model(hparams=hparams, encoder=enc_malay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cd5bfbec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open a session and initialise every global variable; the checkpoint\n",
    "# restore in the next cell then loads the stored values into them.\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f751b579",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from gs://mesolitica-tpu-general/gpt2-345m/model.ckpt-380000\n"
     ]
    }
   ],
   "source": [
    "# Restore all global variables from the checkpoint at the gs:// path below.\n",
    "var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)\n",
    "saver = tf.train.Saver(var_list = var_list)\n",
    "saver.restore(sess, 'gs://mesolitica-tpu-general/gpt2-345m/model.ckpt-380000')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2442dcff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode a sample prompt with the BPE encoder; the cell displays how many\n",
    "# items the encoding contains.\n",
    "string = 'mahathir dan najib razak sangat sayangkan anwar ibrahim'\n",
    "encoded = enc_malay.encode(string)\n",
    "len(encoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "542f0385",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample 100 new tokens for a single copy of the prompt with top-p = 0.7.\n",
    "o = sess.run(\n",
    "    model._model,\n",
    "    feed_dict={\n",
    "        model._X: [encoded],\n",
    "        model._temperature: 0.0,\n",
    "        model._top_k: 0,\n",
    "        model._top_p: 0.7,\n",
    "        model._maxlen: 100,\n",
    "        model._n_samples: 1,\n",
    "    },\n",
    ")\n",
    "o.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20b5e629",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the first generated row back to text.\n",
    "print(enc_malay.decode(o[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58851d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `o` holds one row per requested sample. Iterate over the rows after the\n",
    "# first instead of indexing o[1] directly, which raises IndexError when the\n",
    "# run above was fed n_samples = 1.\n",
    "for extra_sample in o[1:]:\n",
    "    print(enc_malay.decode(extra_sample))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0ec9b89d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'gpt2-345m/model.ckpt'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Write a local checkpoint of the current session; freeze_graph later reads\n",
    "# this checkpoint and its '.meta' file.\n",
    "saver = tf.train.Saver()\n",
    "saver.save(sess, 'gpt2-345m/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "71b58d61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['X',\n",
       " 'temp',\n",
       " 'top_k',\n",
       " 'top_p',\n",
       " 'maxlen',\n",
       " 'n_samples',\n",
       " 'model/wpe',\n",
       " 'model/wte',\n",
       " 'model/h0/ln_1/g',\n",
       " 'model/h0/ln_1/b',\n",
       " 'model/h0/attn/c_attn/w',\n",
       " 'model/h0/attn/c_attn/b',\n",
       " 'model/h0/attn/c_proj/w',\n",
       " 'model/h0/attn/c_proj/b',\n",
       " 'model/h0/ln_2/g',\n",
       " 'model/h0/ln_2/b',\n",
       " 'model/h0/mlp/c_fc/w',\n",
       " 'model/h0/mlp/c_fc/b',\n",
       " 'model/h0/mlp/c_proj/w',\n",
       " 'model/h0/mlp/c_proj/b',\n",
       " 'model/h1/ln_1/g',\n",
       " 'model/h1/ln_1/b',\n",
       " 'model/h1/attn/c_attn/w',\n",
       " 'model/h1/attn/c_attn/b',\n",
       " 'model/h1/attn/c_proj/w',\n",
       " 'model/h1/attn/c_proj/b',\n",
       " 'model/h1/ln_2/g',\n",
       " 'model/h1/ln_2/b',\n",
       " 'model/h1/mlp/c_fc/w',\n",
       " 'model/h1/mlp/c_fc/b',\n",
       " 'model/h1/mlp/c_proj/w',\n",
       " 'model/h1/mlp/c_proj/b',\n",
       " 'model/h2/ln_1/g',\n",
       " 'model/h2/ln_1/b',\n",
       " 'model/h2/attn/c_attn/w',\n",
       " 'model/h2/attn/c_attn/b',\n",
       " 'model/h2/attn/c_proj/w',\n",
       " 'model/h2/attn/c_proj/b',\n",
       " 'model/h2/ln_2/g',\n",
       " 'model/h2/ln_2/b',\n",
       " 'model/h2/mlp/c_fc/w',\n",
       " 'model/h2/mlp/c_fc/b',\n",
       " 'model/h2/mlp/c_proj/w',\n",
       " 'model/h2/mlp/c_proj/b',\n",
       " 'model/h3/ln_1/g',\n",
       " 'model/h3/ln_1/b',\n",
       " 'model/h3/attn/c_attn/w',\n",
       " 'model/h3/attn/c_attn/b',\n",
       " 'model/h3/attn/c_proj/w',\n",
       " 'model/h3/attn/c_proj/b',\n",
       " 'model/h3/ln_2/g',\n",
       " 'model/h3/ln_2/b',\n",
       " 'model/h3/mlp/c_fc/w',\n",
       " 'model/h3/mlp/c_fc/b',\n",
       " 'model/h3/mlp/c_proj/w',\n",
       " 'model/h3/mlp/c_proj/b',\n",
       " 'model/h4/ln_1/g',\n",
       " 'model/h4/ln_1/b',\n",
       " 'model/h4/attn/c_attn/w',\n",
       " 'model/h4/attn/c_attn/b',\n",
       " 'model/h4/attn/c_proj/w',\n",
       " 'model/h4/attn/c_proj/b',\n",
       " 'model/h4/ln_2/g',\n",
       " 'model/h4/ln_2/b',\n",
       " 'model/h4/mlp/c_fc/w',\n",
       " 'model/h4/mlp/c_fc/b',\n",
       " 'model/h4/mlp/c_proj/w',\n",
       " 'model/h4/mlp/c_proj/b',\n",
       " 'model/h5/ln_1/g',\n",
       " 'model/h5/ln_1/b',\n",
       " 'model/h5/attn/c_attn/w',\n",
       " 'model/h5/attn/c_attn/b',\n",
       " 'model/h5/attn/c_proj/w',\n",
       " 'model/h5/attn/c_proj/b',\n",
       " 'model/h5/ln_2/g',\n",
       " 'model/h5/ln_2/b',\n",
       " 'model/h5/mlp/c_fc/w',\n",
       " 'model/h5/mlp/c_fc/b',\n",
       " 'model/h5/mlp/c_proj/w',\n",
       " 'model/h5/mlp/c_proj/b',\n",
       " 'model/h6/ln_1/g',\n",
       " 'model/h6/ln_1/b',\n",
       " 'model/h6/attn/c_attn/w',\n",
       " 'model/h6/attn/c_attn/b',\n",
       " 'model/h6/attn/c_proj/w',\n",
       " 'model/h6/attn/c_proj/b',\n",
       " 'model/h6/ln_2/g',\n",
       " 'model/h6/ln_2/b',\n",
       " 'model/h6/mlp/c_fc/w',\n",
       " 'model/h6/mlp/c_fc/b',\n",
       " 'model/h6/mlp/c_proj/w',\n",
       " 'model/h6/mlp/c_proj/b',\n",
       " 'model/h7/ln_1/g',\n",
       " 'model/h7/ln_1/b',\n",
       " 'model/h7/attn/c_attn/w',\n",
       " 'model/h7/attn/c_attn/b',\n",
       " 'model/h7/attn/c_proj/w',\n",
       " 'model/h7/attn/c_proj/b',\n",
       " 'model/h7/ln_2/g',\n",
       " 'model/h7/ln_2/b',\n",
       " 'model/h7/mlp/c_fc/w',\n",
       " 'model/h7/mlp/c_fc/b',\n",
       " 'model/h7/mlp/c_proj/w',\n",
       " 'model/h7/mlp/c_proj/b',\n",
       " 'model/h8/ln_1/g',\n",
       " 'model/h8/ln_1/b',\n",
       " 'model/h8/attn/c_attn/w',\n",
       " 'model/h8/attn/c_attn/b',\n",
       " 'model/h8/attn/c_proj/w',\n",
       " 'model/h8/attn/c_proj/b',\n",
       " 'model/h8/ln_2/g',\n",
       " 'model/h8/ln_2/b',\n",
       " 'model/h8/mlp/c_fc/w',\n",
       " 'model/h8/mlp/c_fc/b',\n",
       " 'model/h8/mlp/c_proj/w',\n",
       " 'model/h8/mlp/c_proj/b',\n",
       " 'model/h9/ln_1/g',\n",
       " 'model/h9/ln_1/b',\n",
       " 'model/h9/attn/c_attn/w',\n",
       " 'model/h9/attn/c_attn/b',\n",
       " 'model/h9/attn/c_proj/w',\n",
       " 'model/h9/attn/c_proj/b',\n",
       " 'model/h9/ln_2/g',\n",
       " 'model/h9/ln_2/b',\n",
       " 'model/h9/mlp/c_fc/w',\n",
       " 'model/h9/mlp/c_fc/b',\n",
       " 'model/h9/mlp/c_proj/w',\n",
       " 'model/h9/mlp/c_proj/b',\n",
       " 'model/h10/ln_1/g',\n",
       " 'model/h10/ln_1/b',\n",
       " 'model/h10/attn/c_attn/w',\n",
       " 'model/h10/attn/c_attn/b',\n",
       " 'model/h10/attn/c_proj/w',\n",
       " 'model/h10/attn/c_proj/b',\n",
       " 'model/h10/ln_2/g',\n",
       " 'model/h10/ln_2/b',\n",
       " 'model/h10/mlp/c_fc/w',\n",
       " 'model/h10/mlp/c_fc/b',\n",
       " 'model/h10/mlp/c_proj/w',\n",
       " 'model/h10/mlp/c_proj/b',\n",
       " 'model/h11/ln_1/g',\n",
       " 'model/h11/ln_1/b',\n",
       " 'model/h11/attn/c_attn/w',\n",
       " 'model/h11/attn/c_attn/b',\n",
       " 'model/h11/attn/c_proj/w',\n",
       " 'model/h11/attn/c_proj/b',\n",
       " 'model/h11/ln_2/g',\n",
       " 'model/h11/ln_2/b',\n",
       " 'model/h11/mlp/c_fc/w',\n",
       " 'model/h11/mlp/c_fc/b',\n",
       " 'model/h11/mlp/c_proj/w',\n",
       " 'model/h11/mlp/c_proj/b',\n",
       " 'model/h12/ln_1/g',\n",
       " 'model/h12/ln_1/b',\n",
       " 'model/h12/attn/c_attn/w',\n",
       " 'model/h12/attn/c_attn/b',\n",
       " 'model/h12/attn/c_proj/w',\n",
       " 'model/h12/attn/c_proj/b',\n",
       " 'model/h12/ln_2/g',\n",
       " 'model/h12/ln_2/b',\n",
       " 'model/h12/mlp/c_fc/w',\n",
       " 'model/h12/mlp/c_fc/b',\n",
       " 'model/h12/mlp/c_proj/w',\n",
       " 'model/h12/mlp/c_proj/b',\n",
       " 'model/h13/ln_1/g',\n",
       " 'model/h13/ln_1/b',\n",
       " 'model/h13/attn/c_attn/w',\n",
       " 'model/h13/attn/c_attn/b',\n",
       " 'model/h13/attn/c_proj/w',\n",
       " 'model/h13/attn/c_proj/b',\n",
       " 'model/h13/ln_2/g',\n",
       " 'model/h13/ln_2/b',\n",
       " 'model/h13/mlp/c_fc/w',\n",
       " 'model/h13/mlp/c_fc/b',\n",
       " 'model/h13/mlp/c_proj/w',\n",
       " 'model/h13/mlp/c_proj/b',\n",
       " 'model/h14/ln_1/g',\n",
       " 'model/h14/ln_1/b',\n",
       " 'model/h14/attn/c_attn/w',\n",
       " 'model/h14/attn/c_attn/b',\n",
       " 'model/h14/attn/c_proj/w',\n",
       " 'model/h14/attn/c_proj/b',\n",
       " 'model/h14/ln_2/g',\n",
       " 'model/h14/ln_2/b',\n",
       " 'model/h14/mlp/c_fc/w',\n",
       " 'model/h14/mlp/c_fc/b',\n",
       " 'model/h14/mlp/c_proj/w',\n",
       " 'model/h14/mlp/c_proj/b',\n",
       " 'model/h15/ln_1/g',\n",
       " 'model/h15/ln_1/b',\n",
       " 'model/h15/attn/c_attn/w',\n",
       " 'model/h15/attn/c_attn/b',\n",
       " 'model/h15/attn/c_proj/w',\n",
       " 'model/h15/attn/c_proj/b',\n",
       " 'model/h15/ln_2/g',\n",
       " 'model/h15/ln_2/b',\n",
       " 'model/h15/mlp/c_fc/w',\n",
       " 'model/h15/mlp/c_fc/b',\n",
       " 'model/h15/mlp/c_proj/w',\n",
       " 'model/h15/mlp/c_proj/b',\n",
       " 'model/h16/ln_1/g',\n",
       " 'model/h16/ln_1/b',\n",
       " 'model/h16/attn/c_attn/w',\n",
       " 'model/h16/attn/c_attn/b',\n",
       " 'model/h16/attn/c_proj/w',\n",
       " 'model/h16/attn/c_proj/b',\n",
       " 'model/h16/ln_2/g',\n",
       " 'model/h16/ln_2/b',\n",
       " 'model/h16/mlp/c_fc/w',\n",
       " 'model/h16/mlp/c_fc/b',\n",
       " 'model/h16/mlp/c_proj/w',\n",
       " 'model/h16/mlp/c_proj/b',\n",
       " 'model/h17/ln_1/g',\n",
       " 'model/h17/ln_1/b',\n",
       " 'model/h17/attn/c_attn/w',\n",
       " 'model/h17/attn/c_attn/b',\n",
       " 'model/h17/attn/c_proj/w',\n",
       " 'model/h17/attn/c_proj/b',\n",
       " 'model/h17/ln_2/g',\n",
       " 'model/h17/ln_2/b',\n",
       " 'model/h17/mlp/c_fc/w',\n",
       " 'model/h17/mlp/c_fc/b',\n",
       " 'model/h17/mlp/c_proj/w',\n",
       " 'model/h17/mlp/c_proj/b',\n",
       " 'model/h18/ln_1/g',\n",
       " 'model/h18/ln_1/b',\n",
       " 'model/h18/attn/c_attn/w',\n",
       " 'model/h18/attn/c_attn/b',\n",
       " 'model/h18/attn/c_proj/w',\n",
       " 'model/h18/attn/c_proj/b',\n",
       " 'model/h18/ln_2/g',\n",
       " 'model/h18/ln_2/b',\n",
       " 'model/h18/mlp/c_fc/w',\n",
       " 'model/h18/mlp/c_fc/b',\n",
       " 'model/h18/mlp/c_proj/w',\n",
       " 'model/h18/mlp/c_proj/b',\n",
       " 'model/h19/ln_1/g',\n",
       " 'model/h19/ln_1/b',\n",
       " 'model/h19/attn/c_attn/w',\n",
       " 'model/h19/attn/c_attn/b',\n",
       " 'model/h19/attn/c_proj/w',\n",
       " 'model/h19/attn/c_proj/b',\n",
       " 'model/h19/ln_2/g',\n",
       " 'model/h19/ln_2/b',\n",
       " 'model/h19/mlp/c_fc/w',\n",
       " 'model/h19/mlp/c_fc/b',\n",
       " 'model/h19/mlp/c_proj/w',\n",
       " 'model/h19/mlp/c_proj/b',\n",
       " 'model/h20/ln_1/g',\n",
       " 'model/h20/ln_1/b',\n",
       " 'model/h20/attn/c_attn/w',\n",
       " 'model/h20/attn/c_attn/b',\n",
       " 'model/h20/attn/c_proj/w',\n",
       " 'model/h20/attn/c_proj/b',\n",
       " 'model/h20/ln_2/g',\n",
       " 'model/h20/ln_2/b',\n",
       " 'model/h20/mlp/c_fc/w',\n",
       " 'model/h20/mlp/c_fc/b',\n",
       " 'model/h20/mlp/c_proj/w',\n",
       " 'model/h20/mlp/c_proj/b',\n",
       " 'model/h21/ln_1/g',\n",
       " 'model/h21/ln_1/b',\n",
       " 'model/h21/attn/c_attn/w',\n",
       " 'model/h21/attn/c_attn/b',\n",
       " 'model/h21/attn/c_proj/w',\n",
       " 'model/h21/attn/c_proj/b',\n",
       " 'model/h21/ln_2/g',\n",
       " 'model/h21/ln_2/b',\n",
       " 'model/h21/mlp/c_fc/w',\n",
       " 'model/h21/mlp/c_fc/b',\n",
       " 'model/h21/mlp/c_proj/w',\n",
       " 'model/h21/mlp/c_proj/b',\n",
       " 'model/h22/ln_1/g',\n",
       " 'model/h22/ln_1/b',\n",
       " 'model/h22/attn/c_attn/w',\n",
       " 'model/h22/attn/c_attn/b',\n",
       " 'model/h22/attn/c_proj/w',\n",
       " 'model/h22/attn/c_proj/b',\n",
       " 'model/h22/ln_2/g',\n",
       " 'model/h22/ln_2/b',\n",
       " 'model/h22/mlp/c_fc/w',\n",
       " 'model/h22/mlp/c_fc/b',\n",
       " 'model/h22/mlp/c_proj/w',\n",
       " 'model/h22/mlp/c_proj/b',\n",
       " 'model/h23/ln_1/g',\n",
       " 'model/h23/ln_1/b',\n",
       " 'model/h23/attn/c_attn/w',\n",
       " 'model/h23/attn/c_attn/b',\n",
       " 'model/h23/attn/c_proj/w',\n",
       " 'model/h23/attn/c_proj/b',\n",
       " 'model/h23/ln_2/g',\n",
       " 'model/h23/ln_2/b',\n",
       " 'model/h23/mlp/c_fc/w',\n",
       " 'model/h23/mlp/c_fc/b',\n",
       " 'model/h23/mlp/c_proj/w',\n",
       " 'model/h23/mlp/c_proj/b',\n",
       " 'model/ln_f/g',\n",
       " 'model/ln_f/b',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/sort/axis',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/sort/Shape',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/sort/Shape/Switch',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/sort/strided_slice/stack',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/sort/strided_slice/stack_1',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/sort/strided_slice/stack_2',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/sort/strided_slice',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/sort/Rank',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/sort/TopKV2',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/Softmax',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/Cumsum/axis',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/Cumsum',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/Less',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/Less/Switch',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/ones_like/Shape',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/ones_like/Const',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/ones_like',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/mul/y',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/mul',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/Select',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/Min/reduction_indices',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/Min',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/Less_1',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/ones_like_1/Shape',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/ones_like_1/Const',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/ones_like_1',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/mul_1/y',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/mul_1',\n",
       " 'sample_sequence/while/cond_1/top_p_logits/Select_1',\n",
       " 'output']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Comma-separated node names handed to freeze_graph as outputs: variable and\n",
    "# gather ops plus every node whose name contains one of the placeholder or\n",
    "# output substrings, minus nodes whose name contains an excluded substring.\n",
    "strings = ','.join(\n",
    "    node.name\n",
    "    for node in tf.get_default_graph().as_graph_def().node\n",
    "    if (\n",
    "        'Variable' in node.op\n",
    "        or 'gather' in node.op.lower()\n",
    "        or any(\n",
    "            token in node.name\n",
    "            for token in ('X', 'temp', 'top_', 'maxlen', 'n_samples', 'output')\n",
    "        )\n",
    "    )\n",
    "    and not any(\n",
    "        token in node.name\n",
    "        for token in ('adam', 'global_step', 'Assign', 'ReadVariableOp', 'Gather')\n",
    "    )\n",
    ")\n",
    "strings.split(',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e458886c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "    '''Freeze the checkpoint recorded in `model_dir` into `frozen_model.pb`.\n",
    "\n",
    "    The frozen graph is written next to the checkpoint files and the number\n",
    "    of nodes it contains is printed.\n",
    "\n",
    "    Args:\n",
    "        model_dir: directory that holds the TensorFlow checkpoint state.\n",
    "        output_node_names: comma-separated names of the nodes to keep.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `model_dir` does not exist or records no checkpoint.\n",
    "    '''\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    # get_checkpoint_state returns None when the directory has no checkpoint\n",
    "    # state; fail with a clear message instead of an AttributeError.\n",
    "    if checkpoint is None or not checkpoint.model_checkpoint_path:\n",
    "        raise AssertionError(\n",
    "            'No checkpoint found in directory: %s' % model_dir\n",
    "        )\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    # Keep the checkpoint's directory prefix. rpartition yields an empty prefix\n",
    "    # and separator for a bare file name, so the output then lands in the\n",
    "    # working directory rather than at '/frozen_model.pb'.\n",
    "    checkpoint_dir, separator, _ = input_checkpoint.rpartition('/')\n",
    "    output_graph = checkpoint_dir + separator + 'frozen_model.pb'\n",
    "\n",
    "    with tf.Session(graph=tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices=True\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        # Use the session's own graph rather than the ambient default graph.\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            sess.graph.as_graph_def(),\n",
    "            output_node_names.split(','),\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "41f38c30",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from gpt2-345m/model.ckpt\n",
      "WARNING:tensorflow:From <ipython-input-12-9a7215a4e58a>:23: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.convert_variables_to_constants`\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/graph_util_impl.py:277: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "INFO:tensorflow:Froze 292 variables.\n",
      "INFO:tensorflow:Converted 292 variables to const ops.\n",
      "11680 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "# Freeze the checkpoint in 'gpt2-345m', keeping the nodes listed in `strings`.\n",
    "freeze_graph('gpt2-345m', strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ffeac81e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_graph(frozen_graph_filename):\n",
    "    '''Parse a serialized GraphDef from disk and import it into a new graph.\n",
    "\n",
    "    The imported nodes are placed under the 'import' name scope.\n",
    "    '''\n",
    "    parsed_def = tf.GraphDef()\n",
    "    with tf.gfile.GFile(frozen_graph_filename, 'rb') as pb_file:\n",
    "        parsed_def.ParseFromString(pb_file.read())\n",
    "\n",
    "    loaded = tf.Graph()\n",
    "    with loaded.as_default():\n",
    "        tf.import_graph_def(parsed_def, name='import')\n",
    "\n",
    "    return loaded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5cbe978b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the frozen graph into its own tf.Graph for the test run below.\n",
    "g = load_graph('gpt2-345m/frozen_model.pb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ec52de",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_nodes = ['X', 'temp', 'top_k', 'top_p', 'maxlen', 'n_samples']\n",
    "output_nodes = ['output']\n",
    "\n",
    "# Look each node up under the 'import/' prefix; ':0' is the op's first output.\n",
    "inputs = {\n",
    "    node_name: g.get_tensor_by_name('import/%s:0' % node_name)\n",
    "    for node_name in input_nodes\n",
    "}\n",
    "outputs = {\n",
    "    node_name: g.get_tensor_by_name('import/%s:0' % node_name)\n",
    "    for node_name in output_nodes\n",
    "}\n",
    "inputs, outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9893e8ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Session bound to the graph loaded from the frozen file.\n",
    "test_sess = tf.Session(graph = g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2ac02e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the frozen graph: 100 new tokens, one sample, top-k = 40.\n",
    "o = test_sess.run(\n",
    "    outputs['output'],\n",
    "    feed_dict={\n",
    "        inputs['X']: [encoded],\n",
    "        inputs['temp']: 0.0,\n",
    "        inputs['top_k']: 40,\n",
    "        inputs['top_p']: 0.0,\n",
    "        inputs['maxlen']: 100,\n",
    "        inputs['n_samples']: 1,\n",
    "    },\n",
    ")\n",
    "o.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d6b4455",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the row produced by the frozen graph.\n",
    "print(enc_malay.decode(o[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d7623f5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.tools.graph_transforms import TransformGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "3075100b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-17-ba353da273a5>:15: FastGFile.__init__ (from tensorflow.python.platform.gfile) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.gfile.GFile.\n"
     ]
    }
   ],
   "source": [
    "# Graph-transform passes applied, in this order, to the frozen graph.\n",
    "transforms = [\n",
    "    'add_default_attributes',\n",
    "    'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',\n",
    "    'fold_batch_norms',\n",
    "    'fold_old_batch_norms',\n",
    "    'quantize_weights(fallback_min=-10, fallback_max=10)',\n",
    "    'strip_unused_nodes',\n",
    "    'sort_by_execution_order',\n",
    "]\n",
    "\n",
    "input_nodes = ['X', 'temp', 'top_k', 'top_p', 'maxlen', 'n_samples']\n",
    "output_nodes = ['output']\n",
    "\n",
    "pb = 'gpt2-345m/frozen_model.pb'\n",
    "\n",
    "input_graph_def = tf.GraphDef()\n",
    "# tf.gfile.FastGFile is deprecated; tf.gfile.GFile is its replacement and is\n",
    "# what the other cells of this notebook already use.\n",
    "with tf.gfile.GFile(pb, 'rb') as f:\n",
    "    input_graph_def.ParseFromString(f.read())\n",
    "\n",
    "transformed_graph_def = TransformGraph(\n",
    "    input_graph_def, input_nodes, output_nodes, transforms\n",
    ")\n",
    "\n",
    "# The transformed graph is stored next to the frozen one with a '.quantized' suffix.\n",
    "with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:\n",
    "    f.write(transformed_graph_def.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d91b8bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the quantized graph; this rebinds `g`, replacing the unquantized graph.\n",
    "g = load_graph('gpt2-345m/frozen_model.pb.quantized')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f0b7feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_nodes = ['X', 'temp', 'top_k', 'top_p', 'maxlen', 'n_samples']\n",
    "output_nodes = ['output']\n",
    "\n",
    "# Look each node up under the 'import/' prefix; ':0' is the op's first output.\n",
    "inputs = {\n",
    "    node_name: g.get_tensor_by_name('import/%s:0' % node_name)\n",
    "    for node_name in input_nodes\n",
    "}\n",
    "outputs = {\n",
    "    node_name: g.get_tensor_by_name('import/%s:0' % node_name)\n",
    "    for node_name in output_nodes\n",
    "}\n",
    "inputs, outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea326f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Session bound to the graph `g` loaded above.\n",
    "test_sess = tf.Session(graph=g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "638c68cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values fed to the graph's input tensors for a single run.\n",
    "# NOTE(review): temp=0.0 and top_p=0.0 are fed as-is -- confirm the exported\n",
    "# graph handles zeros for these inputs.\n",
    "feed = {\n",
    "    inputs['X']: [encoded],\n",
    "    inputs['temp']: 0.0,\n",
    "    inputs['top_k']: 40,\n",
    "    inputs['top_p']: 0.0,\n",
    "    inputs['maxlen']: 100,\n",
    "    inputs['n_samples']: 1,\n",
    "}\n",
    "o = test_sess.run(outputs['output'], feed_dict=feed)\n",
    "o.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37047931",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the first row of `o` with the BPE encoder and print the text.\n",
    "decoded = enc_malay.decode(o[0])\n",
    "print(decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "16f70d9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt2-117m/\n",
      "gpt2-117m/checkpoint\n",
      "gpt2-117m/frozen_model.pb.quantized\n",
      "gpt2-117m/model.ckpt.index\n",
      "gpt2-117m/model.ckpt.data-00000-of-00001\n",
      "gpt2-117m/frozen_model.pb\n",
      "gpt2-117m/model.ckpt.meta\n"
     ]
    }
   ],
   "source": [
    "# Pack the gpt2-117m directory into a gzip-compressed tarball (verbose listing).\n",
    "!tar -czvf gpt2-117m.tar.gz gpt2-117m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ccc85735",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt2-345m/\n",
      "gpt2-345m/checkpoint\n",
      "gpt2-345m/frozen_model.pb.quantized\n",
      "gpt2-345m/model.ckpt.index\n",
      "gpt2-345m/model.ckpt.data-00000-of-00001\n",
      "gpt2-345m/frozen_model.pb\n",
      "gpt2-345m/model.ckpt.meta\n"
     ]
    }
   ],
   "source": [
    "# Pack the gpt2-345m directory into a gzip-compressed tarball (verbose listing).\n",
    "!tar -czvf gpt2-345m.tar.gz gpt2-345m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d8c344e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from b2sdk.v1 import B2Api, InMemoryAccountInfo\n",
    "\n",
    "# Read the B2 credentials from the environment so the cell runs on a fresh\n",
    "# kernel and no secret is stored in the notebook. A missing variable raises\n",
    "# KeyError here, before any network call is made.\n",
    "application_key_id = os.environ['B2_APPLICATION_KEY_ID']\n",
    "application_key = os.environ['B2_APPLICATION_KEY']\n",
    "\n",
    "info = InMemoryAccountInfo()\n",
    "b2_api = B2Api(info)\n",
    "b2_api.authorize_account(\"production\", application_key_id, application_key)\n",
    "\n",
    "# File info dict passed as `file_infos` by the upload cells below.\n",
    "file_info = {'how': 'good-file'}\n",
    "b2_bucket = b2_api.get_bucket_by_name('malaya-model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "1e38ac8a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FileVersionInfo('4_zcde33cc461767caf742c0b11_f213f274edc6f941d_d20210923_m093316_c000_v0001089_t0050', 'pretrained/gpt2-345m.tar.gz', 2928155236, 'application/octet-stream', 'none', {'how': 'good-file'}, 1632389596000, <EncryptionSetting(EncryptionMode.NONE, None, None)>, <LegalHold.UNSET: None>, FileRetentionSetting(None, None), 1632389596000, None, None, None, 'upload', <b2sdk.v1.api.B2Api object at 0x7f79ea6604e0>)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the 345M archive to the bucket as 'pretrained/gpt2-345m.tar.gz'.\n",
    "b2_bucket.upload_local_file(\n",
    "    local_file='gpt2-345m.tar.gz',\n",
    "    file_name='pretrained/gpt2-345m.tar.gz',\n",
    "    file_infos=file_info,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "dbd54a74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FileVersionInfo('4_zcde33cc461767caf742c0b11_f201775d542477dc5_d20210923_m093507_c000_v0001088_t0002', 'pretrained/gpt2-117m.tar.gz', 1026795428, 'application/octet-stream', 'none', {'how': 'good-file'}, 1632389707000, <EncryptionSetting(EncryptionMode.NONE, None, None)>, <LegalHold.UNSET: None>, FileRetentionSetting(None, None), 1632389707000, None, None, None, 'upload', <b2sdk.v1.api.B2Api object at 0x7f79ea6604e0>)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the 117M archive to the bucket as 'pretrained/gpt2-117m.tar.gz'.\n",
    "b2_bucket.upload_local_file(\n",
    "    local_file='gpt2-117m.tar.gz',\n",
    "    file_name='pretrained/gpt2-117m.tar.gz',\n",
    "    file_infos=file_info,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "34a9a955",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FileVersionInfo('4_zcde33cc461767caf742c0b11_f222b70c64e426192_d20210923_m093548_c000_v0001400_t0021', 'gpt2/345M/model.pb', 1421190629, 'application/octet-stream', 'none', {'how': 'good-file'}, 1632389748000, <EncryptionSetting(EncryptionMode.NONE, None, None)>, <LegalHold.UNSET: None>, FileRetentionSetting(None, None), 1632389748000, None, None, None, 'upload', <b2sdk.v1.api.B2Api object at 0x7f79ea6604e0>)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the 345M frozen graph to the bucket as 'gpt2/345M/model.pb'.\n",
    "b2_bucket.upload_local_file(\n",
    "    local_file='gpt2-345m/frozen_model.pb',\n",
    "    file_name='gpt2/345M/model.pb',\n",
    "    file_infos=file_info,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "a35bfe41",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FileVersionInfo('4_zcde33cc461767caf742c0b11_f215e2ed8200f618a_d20210923_m093708_c000_v0001088_t0034', 'gpt2/345M-quantized/model.pb', 356782059, 'application/octet-stream', 'none', {'how': 'good-file'}, 1632389828000, <EncryptionSetting(EncryptionMode.NONE, None, None)>, <LegalHold.UNSET: None>, FileRetentionSetting(None, None), 1632389828000, None, None, None, 'upload', <b2sdk.v1.api.B2Api object at 0x7f79ea6604e0>)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the 345M quantized graph to the bucket as 'gpt2/345M-quantized/model.pb'.\n",
    "b2_bucket.upload_local_file(\n",
    "    local_file='gpt2-345m/frozen_model.pb.quantized',\n",
    "    file_name='gpt2/345M-quantized/model.pb',\n",
    "    file_infos=file_info,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1ec42c4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
